import pandas as pd
import numpy as np
from scipy.stats import norm as norm
import copy
from src.models_new import simulation, surplus, prices, \
    first_stage_utils, model_objects, value_search
from sklearn.linear_model import LinearRegression as Ols
from sklearn.linear_model import LogisticRegression as Logit
from sklearn.preprocessing import PolynomialFeatures
from dask.distributed import Client
from numba import njit
import itertools

def get_value_function_components_from_sim_wrapper(
        t_i_tuple, prob_new_match_by_ym, prob_extension_by_ym, state_simple_by_ym,
        search_grid_by_spec, search_value_by_spec, const, r, sigma,
        params, prob_extend_by_spec, share_available_by_spec,
        extension_denom_by_tau_by_spec, n_rigs):
    """Build the match counts and average new-contract prices for one period.

    Args:
        t_i_tuple: pair ``(t, i)``; ``t`` indexes ``state_simple_by_ym`` and
            ``i`` indexes ``prob_new_match_by_ym`` / ``share_available_by_spec``.
        prob_new_match_by_ym: ``[i][spec][tau]`` gives an array of new-match
            probabilities over the mri nodes.
        prob_extension_by_ym: unused here.
        state_simple_by_ym: ``[t]`` is forwarded to ``prices.get_price_all``
            as ``state``.
        search_grid_by_spec, search_value_by_spec, prob_extend_by_spec:
            per-spec inputs forwarded to ``prices.get_price_all``.
        const, r, sigma: objects exposing ``.values``; forwarded (reshaped as
            in the body) to ``prices.get_price_all``.
        params: parameter dict. NOTE: ``params['m_0']`` and ``params['m_1']``
            are overwritten in place with the per-spec values.
        share_available_by_spec: ``[i][spec]`` share of rigs available.
        extension_denom_by_tau_by_spec: unused here.
        n_rigs: number of rigs per spec.

    Returns:
        dict with keys ``n_match_{spec}_{tau}``, ``price_new_{spec}_{tau}``
        and ``n_no_match_{spec}`` for spec in low/mid/high and tau in 2/3/4.
    """
    t, i = t_i_tuple
    print(t_i_tuple)

    mri_nodes = np.linspace(0, 2.15, 30)
    components = {}
    for spec in ('low', 'mid', 'high'):
        match_totals = {}
        for tau in (2, 3, 4):
            match_mass = prob_new_match_by_ym[i][spec][tau]
            match_totals[tau] = match_mass.sum()
            components[f'n_match_{spec}_{tau}'] = match_totals[tau] * n_rigs[spec]

            # The pricing routine reads the generic m_0 / m_1 entries, so
            # point them at the values for the current rig spec.
            params['m_0'] = params[f'm_0_{spec}']
            params['m_1'] = params[f'm_1_{spec}']

            # Price at every mri node for this (spec, tau)
            node_prices = np.array([
                prices.get_price_all(
                    mri=mri,
                    tau=tau,
                    state=state_simple_by_ym[t],
                    grid=search_grid_by_spec[spec],
                    values=search_value_by_spec[spec],
                    const=const.values.T[0],
                    r=r.values,
                    sigma=sigma.values[0][0],
                    prob_extend=prob_extend_by_spec[spec],
                    params=params,
                    seeds=range(1)
                )[0]
                for mri in mri_nodes
            ])

            # Average price, weighted by the new-match probabilities
            components[f'price_new_{spec}_{tau}'] = (
                (node_prices * match_mass).sum() / match_mass.sum()
            )

        # Available rigs of this spec that end the period unmatched
        unmatched_share = (
            share_available_by_spec[i][spec]
            - match_totals[2]
            - match_totals[3]
            - match_totals[4]
        )
        components[f'n_no_match_{spec}'] = unmatched_share * n_rigs[spec]
    return components


def get_value_function_components_from_sim(inputs, options):
    """Compute the value-function components for every simulated period.

    Calls ``get_value_function_components_from_sim_wrapper`` once per key of
    ``inputs['prob_new_match_by_ym']``, serially and in iteration order.

    Args:
        inputs: dict of keyword arguments forwarded unchanged to the wrapper.
            Must contain ``'prob_new_match_by_ym'``; its keys define the
            periods that are evaluated.
        options: retained for backward compatibility and currently unused.
            It used to configure a dask ``Client`` (``threads_per_worker``,
            ``n_workers``), but the ``client.map`` dispatch had been disabled,
            so the client was started on every call, never used and never
            closed. The client is therefore no longer created.

    Returns:
        dict mapping each period key to the component dict returned by the
        wrapper for that period.
    """
    return {
        i: get_value_function_components_from_sim_wrapper((t, i), **inputs)
        for t, i in enumerate(inputs['prob_new_match_by_ym'])
    }


def get_first_stage(components_by_ym, state, polynomial_order):
    """Estimate the first-stage policy functions from simulated components.

    Args:
        components_by_ym: dict keyed by date; each value is a dict holding
            ``n_match_{spec}_{tau}``, ``n_no_match_{spec}``,
            ``n_extension_{spec}_{tau}``, ``n_no_extension_{spec}_{tau}``,
            ``price_new_{spec}_{tau}`` and ``price_extend_{spec}_{tau}`` for
            spec in low/mid/high and tau in 2/3/4.
        state: DataFrame with columns ``date``, ``g``, ``n_l``, ``n_m``, ``n_h``.
        polynomial_order: degree passed to ``PolynomialFeatures``.

    Returns:
        Tuple ``(df_extension, results_all_mnl_df, results_all_ols_df,
        df_extend, df_mnl)``: extension price/probability coefficients, MNL
        coefficients, day-rate OLS coefficients, and the two expanded
        observation-level frames used for the extension logit and the MNL.
    """
    specs = ['low', 'mid', 'high']
    taus = [2, 3, 4]

    components = pd.DataFrame(components_by_ym).T

    def _dates_repeated_by(count_col):
        # One formatted date label per (integer-truncated) unit of `count_col`,
        # i.e. counts are expanded into observation-level rows.
        # NOTE(review): Index.format() is deprecated in recent pandas
        # releases -- confirm against the pinned pandas version.
        return pd.Series(
            components.index.repeat(components[count_col].astype(int)).format()
        )

    df_n_matches = dict()
    df_n_extend = dict()
    for spec in specs:
        for tau in taus:
            df_n_matches[(spec, tau - 1)] = _dates_repeated_by(
                f'n_match_{spec}_{tau}')
            df_n_extend[(spec, tau - 1, 1)] = _dates_repeated_by(
                f'n_extension_{spec}_{tau}')
            df_n_extend[(spec, tau - 1, 0)] = _dates_repeated_by(
                f'n_no_extension_{spec}_{tau}')
        df_n_matches[(spec, 0)] = _dates_repeated_by(f'n_no_match_{spec}')

    # Make the dataframes
    df_mnl = (
        pd.concat(df_n_matches)
        .reset_index()
        .rename(columns={'level_0': 'rig_spec', 'level_1': 'tau', 0: 'date'})
        .drop(columns=['level_2'])
    )

    df_extend = (
        pd.concat(df_n_extend)
        .reset_index()
        .rename(
            columns={'level_0': 'rig_spec', 'level_1': 'tau', 'level_2': 'extend',
                     0: 'date'}
        )
    )

    # Get state dataframe
    df = components.reset_index().rename(columns={'index': 'date'})

    # Ensure all dates are in datetime format
    df_mnl['date'] = pd.to_datetime(df_mnl['date'])
    df_extend['date'] = pd.to_datetime(df_extend['date'])

    # %% Get df in a format to do dayrate and extension regressions
    # Do polynomial data
    poly = PolynomialFeatures(polynomial_order, include_bias=False)

    # Setup state names
    state_names = ['g', 'n_l', 'n_m', 'n_h']

    # Get data in the polynomial form
    poly_fit = poly.fit_transform(state[state_names])
    # `get_feature_names` was replaced by `get_feature_names_out` in newer
    # scikit-learn; prefer the new name and fall back for old installations.
    if hasattr(poly, 'get_feature_names_out'):
        poly_names = list(poly.get_feature_names_out(state_names))
    else:
        poly_names = poly.get_feature_names(state_names)

    state_poly = pd.DataFrame(poly_fit,
                              columns=poly_names,
                              index=state['date'])

    def _long_prices(kind, spec):
        # Melt the three `{kind}_{spec}_{tau}` price columns into long form
        # (date, tau, day_rate) and attach the polynomial state terms.
        tau_by_col = {f'{kind}_{spec}_{tau}': tau for tau in taus}
        price_cols = list(tau_by_col)
        return (
            pd.melt(
                df[['date'] + price_cols],
                value_vars=price_cols,
                id_vars=['date'],
                var_name='tau',
                value_name='day_rate'
            )
            .replace(tau_by_col)
            .merge(state_poly, on='date', how='left')
        )

    df_match_ols_by_spec = dict()
    df_extension_by_spec = dict()
    for spec in specs:
        df_match_ols_by_spec[spec] = _long_prices('price_new', spec)
        df_extension_by_spec[spec] = _long_prices('price_extend', spec)
    df_extend = df_extend.merge(
        state_poly,
        on='date',
        how='left'
    )

    # %% GET POLICY FUNCTIONS -----------------------------------------------------------
    # NOTE: poly_names and poly are replaced here by the versions returned
    # from setup_data; everything below uses those.
    (
        contracts_poly, mnl_poly, poly_names, poly
    ) = first_stage_utils.setup_data(
        state, df_mnl, df, polynomial_order, contracts_have_state=False)

    # Do the multinomial logit policy functions
    output_mnl, reg_mnl = first_stage_utils.estimate_mnl(mnl_poly, poly_names, poly)

    # Get the other policy functions
    reg_ols = dict()
    for spec in specs:
        reg_ols[spec] = Ols().fit(
            df_match_ols_by_spec[spec][poly_names + ['tau']],
            df_match_ols_by_spec[spec]['day_rate']
        )

    # %% SAVE THE OLS POLICY FUNCTIONS --------------------------------------------------
    results_all_ols = dict()
    for spec in specs:
        for metric in ['day_rate']:
            results_all_ols[(spec, metric)] = (
                np.append(
                    np.array(reg_ols[spec].intercept_),
                    reg_ols[spec].coef_
                )
            )

    results_all_ols_df = pd.DataFrame(
        results_all_ols,
        index=['intercept'] + poly_names + ['tau']).T

    # %% SAVE THE MNL POLICY FUNCTIONS --------------------------------------------------
    results_all_mnl_coefs = dict()
    results_all_mnl_intercept = dict()
    for spec in specs:
        for metric in ['tau']:
            reg = reg_mnl[f'{metric} {spec}']
            results_all_mnl_coefs[(spec, 0)] = reg.coef_[0, :]
            results_all_mnl_coefs[(spec, 2)] = reg.coef_[1, :]
            results_all_mnl_coefs[(spec, 3)] = reg.coef_[2, :]

            # Add in something for very low prob. events: when the fourth
            # class never occurs there is no fourth coefficient row.
            try:
                results_all_mnl_coefs[(spec, 4)] = reg.coef_[3, :]
            except IndexError:
                print("Could not find any obs. for", spec, 4)
                results_all_mnl_coefs[(spec, 4)] = 0.0

            results_all_mnl_intercept[(spec, 0)] = reg.intercept_[0]
            results_all_mnl_intercept[(spec, 2)] = reg.intercept_[1]
            results_all_mnl_intercept[(spec, 3)] = reg.intercept_[2]
            try:
                results_all_mnl_intercept[(spec, 4)] = reg.intercept_[3]
            except IndexError:
                # Very negative intercept so the missing class gets ~zero prob.
                results_all_mnl_intercept[(spec, 4)] = -30.0

    # Get the policy function coefficients into a nice form
    results_all_mnl_coefs_df = pd.DataFrame(
        results_all_mnl_coefs,
        index=poly_names).T

    results_all_mnl_intercept_df = pd.DataFrame(
        results_all_mnl_intercept,
        index=['intercept']).T

    results_all_mnl_df = pd.concat(
        [results_all_mnl_coefs_df, results_all_mnl_intercept_df], axis=1)
    results_all_mnl_df = results_all_mnl_df[['intercept'] + poly_names]

    # %% GET EXTENSIONS -----------------------------------------------------------------
    reg_extension_by_metric_spec = dict()
    for spec in specs:
        # Get the extension price
        reg_extension_price = Ols().fit(df_extension_by_spec[spec][poly_names + ['tau']],
                                        df_extension_by_spec[spec]['day_rate'])

        reg_extension_by_metric_spec[(spec, 'price')] = np.append(
            np.array(reg_extension_price.intercept_),
            reg_extension_price.coef_
        )

        # Get extension probability
        # NOTE(review): df_extend is not filtered by rig_spec, so the same
        # pooled logit is fitted (and stored) for every spec -- confirm that
        # pooling is intended.
        reg_extension_prob = Logit(max_iter=10000).fit(
            df_extend[poly_names + ['tau']],
            df_extend['extend']
        )
        reg_extension_by_metric_spec[(spec, 'prob')] = np.append(
            np.array(reg_extension_prob.intercept_),
            reg_extension_prob.coef_
        )

    # Get the extension probability
    df_extension = pd.DataFrame(
        reg_extension_by_metric_spec,
        index=['intercept'] + poly_names + ['tau']).T

    # Finally, save them
    return df_extension, results_all_mnl_df, results_all_ols_df, df_extend, df_mnl


def do_intermediary_matching(state, n_enter_by_tau, mri_reversed, spec_mri_ordered, n_rigs):
    """Greedily match entering wells to available rigs in a given priority order.

    Args:
        state: indexable with the available mass of low / mid / high rigs at
            positions 1, 2 and 3.
        n_enter_by_tau: dict ``tau -> array`` of entering well mass per mri
            bin. Not modified (a deep copy is consumed).
        mri_reversed: mri grid; only its length is used here.
        spec_mri_ordered: mapping whose values start with a pair
            ``(mri_loc, spec_code)``, visited in iteration order. ``spec_code``
            is 0 (low), 1 (mid) or 2 (high).
        n_rigs: dict ``spec -> number of rigs``.

    Returns:
        dict ``(tau, spec) -> np.ndarray`` with, per mri bin, the matched mass
        divided by ``n_rigs[spec]`` (0.0 for bins never visited for the spec).

    Raises:
        ValueError: if a spec code is not 0, 1 or 2. Previously such a code
            silently reused the spec of the preceding entry.
    """
    spec_codes = ((0, 'low'), (1, 'mid'), (2, 'high'))

    n_available_remain = copy.deepcopy({'low': state[1], 'mid': state[2], 'high': state[3]})
    wells_remain = copy.deepcopy(n_enter_by_tau)
    n_match = dict()

    # Do the matching
    for i in spec_mri_ordered:
        mri_loc = spec_mri_ordered[i][0][0]
        spec_code = spec_mri_ordered[i][0][1]
        spec = next((name for code, name in spec_codes if spec_code == code), None)
        if spec is None:
            raise ValueError(f'Unknown rig spec code {spec_code!r} for entry {i!r}')

        # Longest contracts are served first
        for tau in [4, 3, 2]:
            key = (spec, tau, mri_loc)
            if n_available_remain[spec] <= 0.0:
                n_match[key] = 0.0
            else:
                matched = min(n_available_remain[spec], wells_remain[tau][mri_loc])
                n_match[key] = matched
                n_available_remain[spec] += -matched
                wells_remain[tau][mri_loc] += -matched

    # Convert matches in to a prob. of matching
    prob_match_by_tau_spec = dict()
    for spec in ['high', 'mid', 'low']:
        for tau in [2, 3, 4]:
            probs = list()
            for mri_loc in range(len(mri_reversed)):
                try:
                    probs.append(n_match[(spec, tau, mri_loc)] / n_rigs[spec])
                except (KeyError, ZeroDivisionError):
                    # Bin never visited for this spec (or no rigs of the spec)
                    probs.append(0.0)
            prob_match_by_tau_spec[(tau, spec)] = np.array(probs)

    return prob_match_by_tau_spec


def get_n_enter(entry_prob, data_g, params):
    """Turn entry probabilities into the mass of entering wells.

    Args:
        entry_prob: mapping ``ym -> {tau: entry probability}``; iterated in
            order, with the position ``t`` of each key used to index the draws.
        data_g: per-period values entering the draw rule
            ``d_0 + d_1 * data_g``.
        params: parameter dict. NOTE: updated in place -- ``params['denom']``
            and ``params['p_4']`` are (re)computed and stored.

    Returns:
        dict ``ym -> {tau: array}`` with the entering mass on a 30-point mri
        grid over ``[0, 2.15]`` for tau in 2, 3, 4.
    """
    mri_max = 2.15

    draws = params['d_0'] + params['d_1'] * data_g

    # Normal mass between 0 and mri_max, stored for the density call below
    mass_below_max = norm.cdf(mri_max, loc=params['mu_0'], scale=params['sigma_0'])
    mass_below_zero = norm.cdf(0, loc=params['mu_0'], scale=params['sigma_0'])
    params['denom'] = mass_below_max - mass_below_zero

    mri_array = np.linspace(0, mri_max, 30)
    f_mri_array = model_objects.potential_matches(
        mri_array,
        mu_0=params['mu_0'],
        mu_1=0.0,
        sigma_0=params['sigma_0'],
        sigma_1=1.0,
        weight_lambda=params['weight_lambda'],
        denom=params['denom'],
    )

    # Contract-length shares sum to one
    params['p_4'] = 1 - params['p_3'] - params['p_2']

    # Normalize the density so the grid weights sum to one
    f_discretized = f_mri_array / f_mri_array.sum()

    return {
        ym: {
            tau: draws[t] * params[f'p_{tau}'] * entry_prob[ym][tau] * f_discretized
            for tau in (2, 3, 4)
        }
        for t, ym in enumerate(entry_prob)
    }


def do_matching_with_cutoff(cutoff_low, cutoff_high, mri_reversed, state, n_enter_by_tau):
    """Assign every entering well to a rig spec according to two mri cutoffs.

    Wells with ``mri >= cutoff_high`` go to 'high', wells with
    ``cutoff_low <= mri < cutoff_high`` go to 'mid', all others to 'low'.
    All wells in a bin are assigned; rig availability is not enforced here.

    Args:
        cutoff_low: lower mri cutoff.
        cutoff_high: upper mri cutoff.
        mri_reversed: mri grid in reversed order relative to the bins of
            ``n_enter_by_tau`` (position ``k`` maps to bin ``len - 1 - k``).
            Any grid length is supported; previously 30 was hard-coded.
        state: indexable with the available mass of low / mid / high rigs at
            positions 1, 2 and 3.
        n_enter_by_tau: dict ``tau -> array`` of entering well mass per bin.
            Not modified (a deep copy is consumed).

    Returns:
        Tuple ``(n_match, n_available_remain, wells_remain, n_match_by_spec)``:
        matched mass keyed by ``(spec, tau, mri_loc)``, the (unchanged)
        available rigs per spec, the wells left per tau and bin, and the total
        matched mass per spec.
    """
    n_available_remain = copy.deepcopy({'low': state[1], 'mid': state[2], 'high': state[3]})
    wells_remain = copy.deepcopy(n_enter_by_tau)
    n_match = dict()
    n_match_by_spec = {'low': 0.0, 'mid': 0.0, 'high': 0.0}

    last_loc = len(mri_reversed) - 1

    # Do the matching
    for reversed_loc, mri in enumerate(mri_reversed):
        # Translate the position on the reversed grid into the bin index
        mri_loc = last_loc - reversed_loc
        if mri >= cutoff_high:
            spec = 'high'
        elif cutoff_low <= mri < cutoff_high:
            spec = 'mid'
        else:
            spec = 'low'

        for tau in [4, 3, 2]:
            matched = wells_remain[tau][mri_loc]
            n_match[(spec, tau, mri_loc)] = matched
            wells_remain[tau][mri_loc] += -matched
            n_match_by_spec[spec] += matched

    return n_match, n_available_remain, wells_remain, n_match_by_spec


def do_intermediary_matching_with_frictions(state, n_enter_by_tau, mri_reversed,
                                            n_rigs, params):
    """Pick the pair of mri cutoffs that maximizes total match value.

    Idea: do assortative matching taking into account the frictions in each
    submarket. Every pair ``(cutoff_low, cutoff_high)`` of points on a
    30-point grid over ``[0, 2.15]`` with ``cutoff_low`` strictly before
    ``cutoff_high`` is evaluated: wells above ``cutoff_high`` are assigned to
    the high submarket, wells between the cutoffs to mid, the rest to low
    (see ``do_matching_with_cutoff``). The match probabilities of the pair
    with the largest total value are returned.

    Args:
        state: indexable; position 0 multiplies the gas-price term of the
            match value, positions 1-3 are the available low / mid / high rigs.
        n_enter_by_tau: dict ``tau -> array`` of entering well mass per mri bin.
        mri_reversed: reversed mri grid, forwarded to
            ``do_matching_with_cutoff``; its length sets the number of bins.
        n_rigs: dict ``spec -> number of rigs``.
        params: parameter dict; reads ``a_1_{spec}``, ``m_0_{spec}``,
            ``m_1_{spec}``, ``m_2`` and ``rho_0`` .. ``rho_3``.

    Returns:
        dict ``(tau, spec) -> np.ndarray`` of match probabilities per mri bin
        for the value-maximizing cutoff pair.
    """
    state_by_spec = copy.deepcopy({'low': state[1], 'mid': state[2], 'high': state[3]})

    # Do the matching
    mri_state_grid = np.linspace(0, 2.15, 30)
    total_value_by_cutoffs = dict()
    prob_match_by_cutoffs = dict()
    for cutoff_low_loc, cutoff_low in enumerate(mri_state_grid):
        # Only consider upper cutoffs strictly above the lower one
        for cutoff_high in mri_state_grid[cutoff_low_loc + 1:]:
            (
                n_match,
                n_available_remain,
                wells_remain,
                n_match_by_spec
            ) = do_matching_with_cutoff(
                cutoff_low, cutoff_high, mri_reversed, state, n_enter_by_tau)

            # Convert matches in to a prob. of matching
            # (total wells assigned to each spec under this cutoff pair)
            total_n_matches_by_spec = dict()
            for spec in ['high', 'mid', 'low']:
                total_n_matches_by_spec[spec] = 0.0
                for tau in [2, 3, 4]:
                    for mri_loc in range(len(mri_reversed)):
                        try:
                            total_n_matches_by_spec[spec] += n_match[
                                (spec, tau, mri_loc)]
                        # NOTE(review): bare except; the expected case is a
                        # missing (spec, tau, mri_loc) key for bins assigned
                        # to another spec, but every other error is hidden too.
                        except:
                            pass

            # Get probability dist. of matches
            prob_match_by_tau_spec = dict()
            for spec in ['high', 'mid', 'low']:
                for tau in [2, 3, 4]:
                    prob_match_by_tau_spec[(tau, spec)] = list()
                    for mri_loc in range(len(mri_reversed)):
                        try:
                            # Get share of allocated matches for a particular bin conditional on a rig type
                            # Idea: intermediary 'suggests' a targeting rule
                            p_match = n_match[(spec, tau, mri_loc)] / total_n_matches_by_spec[spec]

                            # Get resulting prob. of rig matching in submarket after targeting (note: no rejections)
                            # theta = available rigs per assigned well; it
                            # depends only on spec but is recomputed per bin.
                            theta = state_by_spec[spec] / total_n_matches_by_spec[spec]
                            p_match_submarket = (
                                1 - np.exp(-params[f'a_1_{spec}'] / theta)
                            )
                            # Cap at 1 / theta (wells per available rig)
                            if p_match_submarket > 1 / theta:
                                p_match_submarket = 1 / theta

                            # Prob a well is matched with a (tau, spec) combination AND is available to match:
                            # = prob. of a particular well type give rig is matched
                            # * prob. rig is matched give it is available to match
                            # * prob rig is available to match
                            prob_match_by_tau_spec[(tau, spec)].append(
                                p_match * p_match_submarket * (state_by_spec[spec] / n_rigs[spec])
                            )
                        # NOTE(review): bare except; any failure above
                        # (missing key, division error, ...) records a zero
                        # probability for the bin.
                        except:
                            # print('adding zero')
                            prob_match_by_tau_spec[(tau, spec)].append(0.0)
                    prob_match_by_tau_spec[(tau, spec)] = np.array(
                        prob_match_by_tau_spec[(tau, spec)]
                    )

            prob_match_by_cutoffs[(cutoff_low, cutoff_high)] = prob_match_by_tau_spec

            # Get the total match value
            # (note: not scaled here by well draws since this is the same for all options)
            total_value = 0
            for spec in ['low', 'mid', 'high']:
                # Cubic polynomial in mri scaled by m_2 and state[0]; it does
                # not depend on spec.
                gas_price_portion = (
                    params['m_2'] * state[0] * (
                    params['rho_0']
                    + params['rho_1'] * mri_state_grid
                    + params['rho_2'] * mri_state_grid * mri_state_grid
                    + params['rho_3'] * mri_state_grid * mri_state_grid * mri_state_grid
                    )
                )
                for k, tau in enumerate([2, 3, 4]):
                    # Per-bin match value dotted with the match probabilities
                    output = (
                            (
                                gas_price_portion
                                + params[f'm_1_{spec}'] * mri_state_grid
                                + params[f'm_0_{spec}']
                            )
                            @ prob_match_by_tau_spec[(tau, spec)]
                    ).sum()
                    total_value += output
            total_value_by_cutoffs[(cutoff_low, cutoff_high)] = total_value

    # Select the cutoff pair with the largest total value
    df_total_value_by_cutoffs = pd.Series(total_value_by_cutoffs)
    idxmax = df_total_value_by_cutoffs.idxmax()
    output = prob_match_by_cutoffs[(idxmax[0], idxmax[1])]

    return output


def update_state_detail_intermediary(
        state_detail, n_enter_by_tau, n_rigs, g, params):
    """Advance the detailed state by one period under intermediary matching.

    Extensions are taken from ``model_objects.get_extensions``; new matches
    come from ``do_intermediary_matching_with_frictions``.

    Args:
        state_detail: dict ``spec -> [array, array, array]`` (one 2-D array
            per contract length 2, 3, 4). Updated in place and also returned.
        n_enter_by_tau: dict ``tau -> array`` of entering well mass per mri bin.
        n_rigs: dict ``spec -> number of rigs``.
        g: value forwarded to ``model_objects.get_state_simple`` as ``g``.
        params: parameter dict forwarded to the helpers.

    Returns:
        Tuple ``(state_detail, state, prob_match_by_tau_spec,
        extensions_by_spec_tau)``.
    """
    # Reversed mri grid handed to the matching routine
    mri_descending = np.linspace(2.15, 0, 30)

    # Extensions
    extensions = model_objects.get_extensions(
        cutoff_mask_by_spec=None,
        params=params,
        state_detail=state_detail,
    )

    # Simple state (the two probability outputs are not used here)
    state, _prob_unemployed, _prob_no_extend = model_objects.get_state_simple(
        extensions_by_spec_tau=extensions,
        state_detail=state_detail,
        n_rigs=n_rigs,
        g=g,
        verbose=True,
    )

    # New matches
    new_matches = do_intermediary_matching_with_frictions(
        state=state,
        n_enter_by_tau=n_enter_by_tau,
        mri_reversed=mri_descending,
        n_rigs=n_rigs,
        params=params,
    )

    # Shift every contract one column to the left, then fill the last column
    # with the new matches plus extensions.
    for spec in ('low', 'mid', 'high'):
        contracts = state_detail[spec]
        for slot, tau in zip((0, 1, 2), (2, 3, 4)):
            contracts[slot] = np.roll(contracts[slot], -1, axis=1)
            contracts[slot][:, -1] = new_matches[(tau, spec)] + extensions[spec][tau]

    return state_detail, state, new_matches, extensions


def do_simulation_intermediary(state_data, n_rigs, n_enter_by_ym_tau, params):
    """Simulate the market period by period under intermediary matching.

    The detailed state starts at zero and is burned in on the '2000-01-01'
    row (damped updates, at most 200 iterations, stopping once the simple
    state moves by less than 1e-5). Then every row of ``state_data`` is
    simulated in order.

    Args:
        state_data: DataFrame with columns ``date``, ``g``, ``n_l``, ``n_m``,
            ``n_h``; must contain a row whose date equals '2000-01-01'.
        n_rigs: dict ``spec -> number of rigs``.
        n_enter_by_ym_tau: dict ``date -> {tau: array}`` of entering wells.
        params: parameter dict forwarded to the state update.

    Returns:
        dict with keys ``moments_by_state`` (DataFrame, one row per date),
        ``state_detail_by_ym``, ``state_simple_by_ym``,
        ``prob_new_match_by_ym`` and ``prob_extension_by_ym``. The entries of
        the last two are 1-tuples wrapping the per-date dict.
    """
    specs = ('low', 'mid', 'high')
    grid_size = 30
    mri_state_grid = np.linspace(0, 2.15, grid_size)

    print("DOING MATCHING")

    # Initial state: all-zero arrays with 2, 3 and 4 columns for each spec
    state_detail = {
        spec: [np.zeros((grid_size, n_cols)) for n_cols in (2, 3, 4)]
        for spec in specs
    }
    previous_simple_state = np.array([0.0, 0.0, 0.0, 0.0])

    # Burn in ----------------------------------------------------------------
    is_burn_in_date = state_data['date'] == '2000-01-01'
    burn_in_row = state_data.loc[
        is_burn_in_date, ['date', 'g', 'n_l', 'n_m', 'n_h']].iloc[0]

    for _ in range(200):
        (
            proposed_detail,
            state_simple,
            prob_match_by_tau_spec,
            extensions_by_spec_tau
        ) = update_state_detail_intermediary(
            copy.deepcopy(state_detail),
            n_enter_by_ym_tau[burn_in_row.date],
            n_rigs,
            burn_in_row.g,
            params
        )

        if np.max(np.abs(state_simple - previous_simple_state)) < 0.00001:
            break

        # Damped update: move 10% of the way towards the proposed state
        for spec in specs:
            for k in range(3):
                state_detail[spec][k] = \
                    0.9 * state_detail[spec][k] + 0.1 * proposed_detail[spec][k]

        previous_simple_state = state_simple

    # Simulation -------------------------------------------------------------
    moments_by_state = {}
    state_detail_by_ym = {}
    state_simple_by_ym = {}
    prob_new_match_by_ym = {}
    prob_extension_by_ym = {}
    for row in state_data.itertuples():
        (
            state_detail,
            state_simple,
            prob_match_by_tau_spec,
            extensions_by_spec_tau
        ) = update_state_detail_intermediary(
            state_detail,
            n_enter_by_ym_tau[row.date],
            n_rigs,
            row.g,
            params
        )

        moments_by_state[row.date] = simulation.get_moments(
            state_detail,
            mri_state_grid,
            n_rigs
        )

        # Snapshot the results for this date
        state_detail_by_ym[row.date] = copy.deepcopy(state_detail)
        state_simple_by_ym[row.date] = copy.deepcopy(state_simple)
        # NOTE(review): these two are stored as 1-tuples (the original
        # statements ended in a trailing comma); kept as-is -- confirm that
        # consumers index them with [0].
        prob_new_match_by_ym[row.date] = (copy.deepcopy(prob_match_by_tau_spec),)
        prob_extension_by_ym[row.date] = (copy.deepcopy(extensions_by_spec_tau),)

    return {
        'moments_by_state': pd.DataFrame(moments_by_state).T,
        'state_detail_by_ym': state_detail_by_ym,
        'state_simple_by_ym': state_simple_by_ym,
        'prob_new_match_by_ym': prob_new_match_by_ym,
        'prob_extension_by_ym': prob_extension_by_ym
    }


def get_av_price_and_p_match(params, state, search_grid_by_spec, search_value_by_spec,
                             first_stage_surplus_by_spec, results, date, n_rigs,
                             mri_max):
    """Summarize simulated matches at one date as average prices and match probabilities.

    Args:
        params: parameter dict. NOTE: ``params['m_0']`` and ``params['m_1']``
            are overwritten in place with the per-spec values.
        state: state forwarded to ``prices.get_price_all``.
        search_grid_by_spec, search_value_by_spec: per-spec inputs forwarded
            to ``prices.get_price_all`` as ``grid`` and ``values``.
        first_stage_surplus_by_spec: per-spec dict of extra keyword arguments
            for ``prices.get_price_all``.
        results: simulation output; reads
            ``results['prob_new_match_by_ym'][date][0][(spec, tau)]`` and
            ``results['state_simple_by_ym'][date]``.
        date: key into the two result dicts.
        n_rigs: dict ``spec -> number of rigs``.
        mri_max: upper end of the 30-point mri grid.

    Returns:
        Tuple ``(p_by_spec, q_by_spec)``: per spec, the match-probability
        weighted average price, and the array
        ``[P(no match), P(tau=2), P(tau=3), P(tau=4)]``.
    """
    mri_nodes = np.linspace(0, mri_max, 30)
    new_match_probs = results['prob_new_match_by_ym'][date][0]
    state_simple = results['state_simple_by_ym'][date]

    p_by_spec = {}
    q_by_spec = {}
    for j, spec in enumerate(('low', 'mid', 'high')):
        weighted_price = 0
        prob_match_total = 0.0
        prob_match_by_tau = {}
        for tau in (2, 3, 4):
            # The pricing routine reads the generic m_0 / m_1 entries
            params['m_0'] = params[f'm_0_{spec}']
            params['m_1'] = params[f'm_1_{spec}']

            # Price at every mri node
            node_prices = np.array([
                prices.get_price_all(
                    mri=mri,
                    tau=tau,
                    state=state,
                    grid=search_grid_by_spec[spec],
                    values=search_value_by_spec[spec],
                    params=params,
                    seeds=range(1),
                    **first_stage_surplus_by_spec[spec]
                )[0]
                for mri in mri_nodes
            ])

            match_mass = new_match_probs[(spec, tau)]

            # Average price, weighted by the new-match probabilities
            av_price = (node_prices * match_mass).sum() / match_mass.sum()

            # Match probability relative to the share of available rigs
            prob_match = match_mass.sum() / (state_simple[j + 1] / n_rigs[spec])

            prob_match_by_tau[tau] = prob_match
            prob_match_total += prob_match
            weighted_price += av_price * prob_match

        q_by_spec[spec] = np.array([
            1 - prob_match_total,
            prob_match_by_tau[2],
            prob_match_by_tau[3],
            prob_match_by_tau[4]
        ])
        p_by_spec[spec] = weighted_price / prob_match_total

    return p_by_spec, q_by_spec


def do_simulation_smoothing(params, state_data, n_rigs, mri_max,
                            search_grid_by_spec, search_value_by_spec,
                            config, seeds, options):
    """Iterate simulated prices and match probabilities to a fixed point.

    Each iteration feeds constant per-spec prices and match probabilities
    into the value-function simulation, rebuilds the surplus, re-runs the
    market simulation and reads the implied average prices and match
    probabilities back off at '2000-01-01'. The loop stops once the largest
    change in average price is at most 0.001 or after 20 iterations.

    Args:
        params: parameter dict; reads ``eta`` and ``beta`` and is forwarded to
            the helpers.
        state_data: DataFrame with the state columns. NOTE: modified in place
            -- ``g`` is replaced by its mean and ``n_l`` / ``n_m`` / ``n_h``
            are overwritten with the simulated availability.
        n_rigs: dict ``spec -> number of rigs``.
        mri_max: upper end of the mri grid.
        search_grid_by_spec: per-spec search grid.
        search_value_by_spec: per-spec search values. NOTE: entries are
            replaced in place by the simulated values.
        config: extra keyword arguments for ``surplus.init_fast_surplus``.
        seeds: seeds forwarded to the simulations.
        options: forwarded to ``surplus.init_fast_surplus``.

    Returns:
        Tuple ``(results, state_data, search_value_by_spec,
        first_stage_by_spec)`` from the last iteration.
    """
    date = pd.to_datetime('2000-01-01')
    mean_gas = state_data['g'].mean()
    state_data['g'] = mean_gas
    p_by_spec = {
        'low': 0.04,
        'mid': 0.04,
        'high': 0.04
    }
    q_by_spec = {
        'low': np.array([0.25, 0.25, 0.25, 0.25]),
        'mid': np.array([0.25, 0.25, 0.25, 0.25]),
        'high': np.array([0.25, 0.25, 0.25, 0.25])
    }

    p_by_spec_old = copy.deepcopy(p_by_spec)
    n_iter = 0
    eps = 1.0
    max_iter = 20
    max_eps = 0.001
    while eps > max_eps and n_iter < max_iter:
        print(f'Iteration {n_iter}')
        state = np.array([mean_gas, 10.0, 10.0, 10.0]) # Note: can make n_avail arbitrary
        first_stage_by_spec = dict()
        first_stage_surplus_by_spec = dict()

        # Have some trouble looping over spec so write it out explicitly:
        # each jitted function closes over its own local price / probability.
        p_low = copy.deepcopy(p_by_spec['low'])
        q_low = copy.deepcopy(q_by_spec['low'])

        @njit
        def price_new_low(s, tau):
            return p_low

        @njit
        def price_extend_low(s, tau):
            return p_low

        @njit
        def prob_match_low(s):
            return q_low

        # mid
        p_mid = copy.deepcopy(p_by_spec['mid'])
        q_mid = copy.deepcopy(q_by_spec['mid'])

        @njit
        def price_new_mid(s, tau):
            return p_mid

        @njit
        def price_extend_mid(s, tau):
            return p_mid

        @njit
        def prob_match_mid(s):
            return q_mid

        # high
        p_high = copy.deepcopy(p_by_spec['high'])
        q_high = copy.deepcopy(q_by_spec['high'])

        @njit
        def price_new_high(s, tau):
            return p_high

        @njit
        def price_extend_high(s, tau):
            return p_high

        @njit
        def prob_match_high(s):
            return q_high

        # others:
        eta = params['eta']
        @njit
        def prob_extend(s, p):
            return eta

        # Assemble the first-stage inputs from the explicitly defined functions
        policy_fns_by_spec = {
            'low': (price_new_low, price_extend_low, prob_match_low),
            'mid': (price_new_mid, price_extend_mid, prob_match_mid),
            'high': (price_new_high, price_extend_high, prob_match_high),
        }
        for spec, (price_new, price_extend, prob_match) in policy_fns_by_spec.items():
            first_stage_by_spec[spec] = {
                'price_new': copy.deepcopy(price_new),
                'price_extend': copy.deepcopy(price_extend),
                'prob_match': copy.deepcopy(prob_match),
                'prob_extend': copy.deepcopy(prob_extend),
                'const': state,
                'r': np.zeros((4, 4)),
                'sigma': 0.0
            }

        @njit
        def prob_extend_surplus(s, p, mri):
            return eta

        for spec in ['low', 'mid', 'high']:
            first_stage_surplus_by_spec[spec] = {
                'prob_extend': copy.deepcopy(prob_extend_surplus),
                'const': state,
                'r': np.zeros((4, 4)),
                'sigma': 0.0
            }

        # Get the value functions
        value_sim_by_spec = dict()
        for spec in ['low', 'mid', 'high']:
            value_sim_by_spec[spec] = value_search.get_value_at_state(
                state_initial=state,
                first_stage=first_stage_by_spec[spec],
                beta=params['beta'],
                sim_length=1000,
                seeds=seeds,
                fortnight=False
            )
            # Keep the shape of the existing values, filled with the new value
            search_value_by_spec[spec] = \
                search_value_by_spec[spec] * 0.0 + value_sim_by_spec[spec]

        # Get the resulting surplus
        match_values_by_tau_spec = dict()
        for tau, spec in itertools.product([2, 3, 4], ['low', 'mid', 'high']):
            print(f'Getting surplus for: contract length: {tau}; rig type: {spec}')
            (
                surplus_grid,
                nodes_grid,
                nodes_list,
                match_values_by_tau_spec[(tau, spec)]
            ) = surplus.init_fast_surplus(
                tau=tau,
                spec=spec,
                grid=search_grid_by_spec[spec],
                values=search_value_by_spec[spec],
                beta=params['beta'],
                options=options,
                max_sim_length=1000,
                prob_extend=prob_extend_surplus,
                const=state,
                r=np.zeros((4, 4)),
                sigma=0.0,
                seeds=seeds,
                **config
            )

        # Do simulation
        surplus_values_by_tau_spec = surplus.build_fast_surplus(
            match_values_by_tau_spec, params, surplus_grid)[0]

        results = simulation.do_simulation(
            state_data=state_data,
            params=params,
            surplus_grid=surplus_grid,
            surplus_values_by_tau_spec=surplus_values_by_tau_spec,
            n_rigs=n_rigs,
            mri_max=mri_max
        )

        # Retrieve the av price and prob. of matching from the results (later used
        # to feed back into the value function)
        p_by_spec, q_by_spec = get_av_price_and_p_match(
            params, state, search_grid_by_spec, search_value_by_spec,
            first_stage_surplus_by_spec, results, date, n_rigs, mri_max
        )

        # Feed the simulated availability back into the state data. The high
        # spec is column 'n_h' / state index 3 (this used to write a stray
        # 'n_g' column from index 2, so 'n_h' was never updated).
        state_data['n_l'] = results['state_simple_by_ym'][date][1]
        state_data['n_m'] = results['state_simple_by_ym'][date][2]
        state_data['n_h'] = results['state_simple_by_ym'][date][3]

        # Update iteration parameters
        n_iter += 1
        eps = np.max([
            np.abs(p_by_spec['low'] - p_by_spec_old['low']),
            np.abs(p_by_spec['mid'] - p_by_spec_old['mid']),
            np.abs(p_by_spec['high'] - p_by_spec_old['high'])
        ])
        p_by_spec_old = copy.deepcopy(p_by_spec)
        print(f'Current av price iter: {p_by_spec}')

    return results, state_data, search_value_by_spec, first_stage_by_spec
